{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# 变分自动编码器\n",
    "变分编码器是自动编码器的升级版本，其结构跟自动编码器是类似的，也由编码器和解码器构成。\n",
    "\n",
    "回忆一下，自动编码器有个问题，就是并不能任意生成图片，因为我们没有办法自己去构造隐藏向量，需要通过一张图片输入编码我们才知道得到的隐含向量是什么，这时我们就可以通过变分自动编码器来解决这个问题。\n",
    "\n",
    "其实原理特别简单，只需要在编码过程给它增加一些限制，迫使其生成的隐含向量能够粗略的遵循一个标准正态分布，这就是其与一般的自动编码器最大的不同。\n",
    "\n",
    "这样我们生成一张新图片就很简单了，我们只需要给它一个标准正态分布的随机隐含向量，这样通过解码器就能够生成我们想要的图片，而不需要给它一张原始图片先编码。\n",
    "\n",
    "一般来讲，我们通过 encoder 得到的隐含向量并不是一个标准的正态分布，为了衡量两种分布的相似程度，我们使用 KL divergence，利用其来表示隐含向量与标准正态分布之间差异的 loss，另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。\n",
    "\n",
    "KL divergence 的公式如下\n",
    "\n",
    "$$\n",
    "D{KL} (P || Q) =  \\int_{-\\infty}^{\\infty} p(x) \\log \\frac{p(x)}{q(x)} dx\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 重参数\n",
    "为了避免计算 KL divergence 中的积分，我们使用重参数的技巧，不是每次产生一个隐含向量，而是生成两个向量，一个表示均值，一个表示标准差，这里我们默认编码之后的隐含向量服从一个正态分布的之后，就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布，最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布，也就是希望均值为 0，方差为 1\n",
    "\n",
    "所以标准的变分自动编码器如下\n",
    "\n",
    "![](https://ws4.sinaimg.cn/large/006tKfTcgy1fn15cq6n7pj30k007t0sv.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所以最后我们可以将我们的 loss 定义为下面的函数，由均方误差和 KL divergence 求和得到一个总的 loss\n",
    "\n",
    "```\n",
    "def loss_function(recon_x, x, mu, logvar):\n",
    "    \"\"\"\n",
    "    recon_x: generating images\n",
    "    x: origin images\n",
    "    mu: latent mean\n",
    "    logvar: latent log variance\n",
    "    \"\"\"\n",
    "    MSE = reconstruction_function(recon_x, x)\n",
    "    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n",
    "    KLD = torch.sum(KLD_element).mul_(-0.5)\n",
    "    # KL divergence\n",
    "    return MSE + KLD\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面我们用 mnist 数据集来简单说明一下变分自动编码器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:41:05.738797Z",
     "start_time": "2018-01-01T10:41:05.215490Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms as tfs\n",
    "from torchvision.utils import save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:41:05.769643Z",
     "start_time": "2018-01-01T10:41:05.741302Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "im_tfs = tfs.Compose([\n",
    "    tfs.ToTensor(),\n",
    "    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化\n",
    "])\n",
    "\n",
    "train_set = MNIST('./mnist', transform=im_tfs)\n",
    "train_data = DataLoader(train_set, batch_size=128, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:41:06.397118Z",
     "start_time": "2018-01-01T10:41:06.306479Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(VAE, self).__init__()\n",
    "\n",
    "        self.fc1 = nn.Linear(784, 400)\n",
    "        self.fc21 = nn.Linear(400, 20) # mean\n",
    "        self.fc22 = nn.Linear(400, 20) # var\n",
    "        self.fc3 = nn.Linear(20, 400)\n",
    "        self.fc4 = nn.Linear(400, 784)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparametrize(self, mu, logvar):\n",
    "        std = logvar.mul(0.5).exp_()\n",
    "        eps = torch.FloatTensor(std.size()).normal_()\n",
    "        if torch.cuda.is_available():\n",
    "            eps = Variable(eps.cuda())\n",
    "        else:\n",
    "            eps = Variable(eps)\n",
    "        return eps.mul(std).add_(mu)\n",
    "\n",
    "    def decode(self, z):\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        return F.tanh(self.fc4(h3))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x) # 编码\n",
    "        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布\n",
    "        return self.decode(z), mu, logvar # 解码，同时输出均值方差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:41:10.056600Z",
     "start_time": "2018-01-01T10:41:06.430817Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = VAE() # 实例化网络\n",
    "if torch.cuda.is_available():\n",
    "    net = net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:41:10.409900Z",
     "start_time": "2018-01-01T10:41:10.059597Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x, _ = train_set[0]\n",
    "x = x.view(x.shape[0], -1)\n",
    "if torch.cuda.is_available():\n",
    "    x = x.cuda()\n",
    "x = Variable(x)\n",
    "_, mu, var = net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:41:29.753678Z",
     "start_time": "2018-01-01T10:41:29.749178Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      "\n",
      "Columns 0 to 9 \n",
      "-0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742\n",
      "\n",
      "Columns 10 to 19 \n",
      "-0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806\n",
      "[torch.cuda.FloatTensor of size 1x20 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(mu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，对于输入，网络可以输出隐含变量的均值和方差，这里的均值方差还没有训练\n",
    "\n",
    "下面开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:13:54.560436Z",
     "start_time": "2018-01-01T10:13:54.530108Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "reconstruction_function = nn.MSELoss(size_average=False)\n",
    "\n",
    "def loss_function(recon_x, x, mu, logvar):\n",
    "    \"\"\"\n",
    "    recon_x: generating images\n",
    "    x: origin images\n",
    "    mu: latent mean\n",
    "    logvar: latent log variance\n",
    "    \"\"\"\n",
    "    MSE = reconstruction_function(recon_x, x)\n",
    "    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n",
    "    KLD = torch.sum(KLD_element).mul_(-0.5)\n",
    "    # KL divergence\n",
    "    return MSE + KLD\n",
    "\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
    "\n",
    "def to_img(x):\n",
    "    '''\n",
    "    定义一个函数将最后的结果转换回图片\n",
    "    '''\n",
    "    x = 0.5 * (x + 1.)\n",
    "    x = x.clamp(0, 1)\n",
    "    x = x.view(x.shape[0], 1, 28, 28)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:35:01.115877Z",
     "start_time": "2018-01-01T10:13:54.562533Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 20, Loss: 61.5803\n",
      "epoch: 40, Loss: 62.9573\n",
      "epoch: 60, Loss: 63.4285\n",
      "epoch: 80, Loss: 64.7138\n",
      "epoch: 100, Loss: 63.3343\n"
     ]
    }
   ],
   "source": [
    "for e in range(100):\n",
    "    for im, _ in train_data:\n",
    "        im = im.view(im.shape[0], -1)\n",
    "        im = Variable(im)\n",
    "        if torch.cuda.is_available():\n",
    "            im = im.cuda()\n",
    "        recon_im, mu, logvar = net(im)\n",
    "        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    if (e + 1) % 20 == 0:\n",
    "        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))\n",
    "        save = to_img(recon_im.cpu().data)\n",
    "        if not os.path.exists('./vae_img'):\n",
    "            os.mkdir('./vae_img')\n",
    "        save_image(save, './vae_img/image_{}.png'.format(e + 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看看使用变分自动编码器得到的结果，可以发现效果比一般的编码器要好很多\n",
    "\n",
    "![](https://ws1.sinaimg.cn/large/006tKfTcgy1fn1ag8832zj306q0a2gmz.jpg)\n",
    "\n",
    "我们可以输出其中的均值看看"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:40:36.481622Z",
     "start_time": "2018-01-01T10:40:36.463332Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x, _ = train_set[0]\n",
    "x = x.view(x.shape[0], -1)\n",
    "if torch.cuda.is_available():\n",
    "    x = x.cuda()\n",
    "x = Variable(x)\n",
    "_, mu, _ = net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-01-01T10:40:37.490484Z",
     "start_time": "2018-01-01T10:40:37.485127Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      "\n",
      "Columns 0 to 9 \n",
      " 0.3861  0.5561  1.1995 -1.6773  0.9867  0.1244 -0.3443 -1.6658  1.3332  1.1606\n",
      "\n",
      "Columns 10 to 19 \n",
      " 0.6898  0.3042  2.1044 -2.4588  0.0504  0.9743  1.1136  0.7872 -0.0777  1.6101\n",
      "[torch.cuda.FloatTensor of size 1x20 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(mu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "变分自动编码器虽然比一般的自动编码器效果要好，而且也限制了其输出的编码 (code) 的概率分布，但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss，这个方式并不好，在下一章生成对抗网络中，我们会讲一讲这种方式计算 loss 的局限性，然后会介绍一种新的训练办法，就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
